Survival Analysis: CatBoost
Demographics Table
| Trait | Transplanted | WL Mortality | P-value |
|---|---|---|---|
| Gender (% Female) | 44.1% (N=4983) | 43.1% (N=540) | 0.709 |
| Age | 4.0 [0.0, 10.5] (N=4983) | 0.0 [0.0, 1.6] (N=540) | <0.001 |
| Weight | 15.0 [0.0, 33.5] (N=4982) | 7.6 [2.4, 12.7] (N=540) | <0.001 |
| Height | 99.0 [55.0, 143.0] (N=4980) | 67.0 [45.6, 88.4] (N=540) | <0.001 |
| BMI | 16.2 [13.7, 18.6] (N=4980) | 15.4 [13.6, 17.2] (N=540) | <0.001 |
| Blood Type A | 36.0% (N=4983) | 33.9% (N=540) | 0.345 |
| Blood Type B | 13.6% (N=4983) | 11.5% (N=540) | 0.190 |
| Blood Type AB | 4.1% (N=4983) | 2.2% (N=540) | 0.042 |
| Blood Type O | 46.2% (N=4983) | 52.4% (N=540) | 0.007 |
| Race White | 53.0% (N=4983) | 50.4% (N=540) | 0.268 |
| Race Black | 20.0% (N=4983) | 23.0% (N=540) | 0.123 |
| Race Other | 3.1% (N=4983) | 3.0% (N=540) | 1.000 |
| VAD % | 13.3% (N=4983) | 8.5% (N=540) | 0.002 |
| eGFR | 96.1 [72.3, 120.0] (N=4966) | 82.2 [54.2, 110.2] (N=540) | <0.001 |
| Albumin | 3.6 [3.1, 4.1] (N=4836) | 3.3 [2.8, 3.8] (N=518) | <0.001 |
| Dialysis % | 1.2% (N=4978) | 4.8% (N=540) | <0.001 |
| Ventilator % | 15.8% (N=4983) | 35.0% (N=540) | <0.001 |
| ECMO % | 4.6% (N=4983) | 14.6% (N=540) | <0.001 |
Native CatBoost Model
Optuna Hyperparameter Optimization
- Best trial is #12/50 with value: 0.74
model_params_native = {
'learning_rate': 0.15,
'depth': 8,
'colsample_bylevel': 0.75,
'min_data_in_leaf': 95,
'l2_leaf_reg': 10.54
}
Model
#| eval: true
#| echo: false
#| message: false
#| warning: false
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold
import pandas as pd
import optuna
from catboost import CatBoostClassifier, Pool
model_auc = CatBoostClassifier(
    objective='Logloss',
    iterations=1000,
    eval_metric="AUC",
    **model_params_native,
    boosting_type='Ordered',
    bootstrap_type='MVS',
    metric_period=25,
    early_stopping_rounds=100,
    use_best_model=False,
    random_seed=1997,
)
# Create Pool objects for the training and testing data
train_pool = Pool(X_train, cat_features=cat_index, label=Y_train)
test_pool = Pool(X_test, cat_features=cat_index, label=Y_test)
model_auc.fit(train_pool, eval_set=test_pool)
Warning: Overfitting detector is active, thus evaluation metric is calculated on every iteration. 'metric_period' is ignored for evaluation metric.
0: test: 0.6234443 best: 0.6234443 (0) total: 156ms remaining: 2m 35s
25: test: 0.7215217 best: 0.7221167 (24) total: 1.02s remaining: 38.4s
50: test: 0.7234322 best: 0.7259924 (37) total: 1.98s remaining: 36.8s
75: test: 0.7192891 best: 0.7259924 (37) total: 3.27s remaining: 39.7s
100: test: 0.7195511 best: 0.7259924 (37) total: 4.74s remaining: 42.2s
125: test: 0.7081423 best: 0.7259924 (37) total: 6.23s remaining: 43.2s
Stopped by overfitting detector (100 iterations wait)
bestTest = 0.7259924014
bestIteration = 37
<catboost.core.CatBoostClassifier object at 0x000002DFA58C3EC0>
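The log above shows training halting once the validation AUC has gone 100 iterations without improving on its best value. A minimal sketch of that kind of patience rule follows, assuming a plain "stop after `wait` iterations without improvement" criterion; the function name `early_stop` and the toy AUC series are illustrative and are not CatBoost internals.

```python
def early_stop(scores, wait):
    """Return (best_iteration, stop_iteration) for a higher-is-better metric.

    Training stops at the first iteration that is `wait` or more steps past
    the best score seen so far.
    """
    best_iter, best_score = 0, scores[0]
    for i, s in enumerate(scores):
        if s > best_score:
            best_iter, best_score = i, s
        elif i - best_iter >= wait:
            return best_iter, i
    # Never triggered: the run used every available iteration.
    return best_iter, len(scores) - 1


# Toy validation AUC curve (invented): improves until iteration 3, then degrades.
toy_auc = [0.62, 0.70, 0.72, 0.73, 0.72, 0.71, 0.71, 0.70, 0.69]
best, stopped = early_stop(toy_auc, wait=3)
print(best, stopped)
```

Because the model above sets `use_best_model=False`, CatBoost does not shrink the model back to the best iteration, so the trees built after iteration 37 stay in the fitted model. This may be part of why the AUC reported afterwards (0.708) sits below `bestTest` (0.726).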
Calibration Plot
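A calibration plot compares predicted probabilities with the event rates actually observed. As a rough illustration, the sketch below builds a binned reliability table in plain Python. This is a simpler technique than a logistic-smoothed calibration curve, and the function name `reliability_bins` and the toy data are assumptions made for the example.

```python
def reliability_bins(probs, labels, n_bins=5):
    """Group predictions into equal-width probability bins and compare the
    mean predicted probability with the observed event rate in each
    non-empty bin."""
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        idx = min(int(p * n_bins), n_bins - 1)  # p == 1.0 falls in the last bin
        bins[idx].append((p, y))
    rows = []
    for idx, members in enumerate(bins):
        if members:
            mean_pred = sum(p for p, _ in members) / len(members)
            event_rate = sum(y for _, y in members) / len(members)
            rows.append((idx, len(members), mean_pred, event_rate))
    return rows


# Toy predictions and outcomes (invented for illustration).
toy_probs = [0.05, 0.10, 0.15, 0.30, 0.35, 0.70, 0.90]
toy_labels = [0, 0, 1, 0, 1, 1, 1]
table = reliability_bins(toy_probs, toy_labels, n_bins=5)
for idx, n, mean_pred, event_rate in table:
    print(f"bin {idx}: n={n} mean_pred={mean_pred:.2f} observed={event_rate:.2f}")
```

A well-calibrated model has `mean_pred` close to `observed` in every bin; bins where the observed rate is higher than predicted indicate underestimated risk.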
Calibrated Model Metrics
Model AUC ... False Negative (FN) True Positive (TP)
0 Native_Catboost 0.707553 ... 35 67
[1 rows x 13 columns]
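The AUC column above can be read as the probability that a randomly chosen non-survivor receives a higher predicted risk than a randomly chosen survivor. The sketch below computes AUC directly from that definition on toy data. It is a quadratic-time illustration with a hypothetical function name, not the routine used to produce the metrics table.

```python
def roc_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs in which the positive
    has the higher score, counting ties as one half."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy outcomes and predicted risks (invented for illustration).
toy_labels = [0, 0, 1, 0, 1, 1]
toy_scores = [0.1, 0.4, 0.35, 0.2, 0.8, 0.4]
auc = roc_auc(toy_labels, toy_scores)
print(round(auc, 4))
```

Because AUC depends only on how cases are ranked, it is unaffected by the probability threshold used to build a confusion matrix.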
SHAP Feature Importance
# Retrieve the ranked comparison dataframe from Python
shap_df <- py$shap_df
# Convert the pandas dataframe to an R tibble and sort by absolute SHAP importance
shap_tbl <- as_tibble(shap_df) %>%
  mutate(SHAP_Importance = abs(Importance_SHAP)) %>%
  arrange(-SHAP_Importance)
# Add an overall rank and keep the columns to display, including the SHAP direction
shap_table <- shap_tbl %>%
  rowid_to_column(var = "Overall Rank") %>%
  select('Overall Rank', 'Feature Id',
         'Importance_SHAP', 'SHAP_Direction',
         'SHAP_Importance'
  )
# Render an interactive DT table with centered columns
shap_table %>%
  DT::datatable(
    rownames = FALSE,
    options = list(
      columnDefs = list(
        list(className = 'dt-center', targets = "_all")
      )
    )
  )
CatBoost - One Hot Encoding (Hybrid)
- Train
[1] "Gender" "Race" "Blood_Type" "VAD_TCR" "WL_Oth_Org"
[6] "Cereb_Vasc" "Diabetes" "Diag_Code" "XMatch_Req" "List_Ctr"
[1] "Race" "Blood_Type" "Diag_Code"
'data.frame': 4523 obs. of 44 variables:
$ outcome : int 0 0 0 1 0 1 0 1 1 1 ...
$ Age : num 1 0 0 0 2 0 1 11 0 0 ...
..- attr(*, "label")= chr "WL AGE AT LISTING IN YEARS"
$ Gender : Factor w/ 2 levels "F","M": 1 2 2 2 1 2 1 1 1 2 ...
$ Weight : num 6.71 3.02 6.6 6 19.4 ...
$ Height : num 72 49.5 66 57 97.5 ...
$ BMI : num 12.9 12.3 15.2 18.5 20.4 ...
$ BSA : num 0.366 0.204 0.348 0.308 0.725 ...
$ PGE_TCR : num 0 0 0 0 0 0 0 0 0 0 ...
..- attr(*, "label")= chr "TCR PGE AT LISTING"
$ ECMO_Reg : num 0 0 0 1 0 0 0 0 0 0 ...
..- attr(*, "label")= chr "TCR ECMO AT LISTING"
$ VAD : num -1 -1 -1 -1 -1 -1 -1 1 -1 1 ...
$ VAD_TCR : Factor w/ 5 levels "LVAD","LVAD+RVAD",..: 3 3 3 3 3 3 3 1 3 2 ...
$ Vent_Reg : num 0 0 0 0 0 0 0 0 0 1 ...
..- attr(*, "label")= chr "TCR LIFE SUPPORT VENTILATOR"
$ WL_Oth_Org : Factor w/ 2 levels "No","Yes": 2 1 1 2 1 1 1 1 1 1 ...
$ Cereb_Vasc : Factor w/ 3 levels "No","Unknown",..: 1 1 1 1 1 1 1 1 1 1 ...
$ Diabetes : Factor w/ 5 levels "None","Type I",..: 1 1 1 1 1 1 1 1 1 1 ...
$ Dialysis : num -1 -1 -1 1 -1 -1 -1 -1 -1 -1 ...
$ Inotrop : num 0 0 0 0 1 0 1 0 1 0 ...
..- attr(*, "label")= chr "TCR IV INOTROPES AT LISTING"
$ Creatinine : num 0.13 0.46 0.28 0.67 0.42 0.34 0.34 0.5 0.22 0.21 ...
..- attr(*, "label")= chr "TCR MOST RECENT CREAT."
$ eGFR : num 228.2 44.3 97.1 35.1 95.6 ...
..- attr(*, "label")= chr "TCR MOST RECENT CREAT."
$ Albumin : num 4.4 3.7 3 3.8 4.6 3.9 4.6 3.6 3.2 3.2 ...
..- attr(*, "label")= chr "TCR TOTAL SERUM ALBUMIN AT LISTING (Pre 1/1/2007 for adult)"
$ Prior_HRTX : num 15 2 13 18 13 15 30 11 27 16 ...
$ Med_Refusals : num 3 8 4 15 5 3 5 5 4 4 ...
$ Prop_Refusals : num 0.864 0.951 0.852 0.972 0.882 ...
$ XMatch_Req : Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 1 1 1 1 ...
$ List_Yr : num 2020 2020 2020 2020 2020 2020 2020 2020 2020 2020 ...
$ Policy_Chg : num 1 1 1 1 1 1 1 1 1 1 ...
$ List_Ctr : Factor w/ 89 levels "ahbent","alfoba",..: 61 34 37 69 16 61 21 71 66 36 ...
$ Race.Asian : num 0 0 0 0 0 0 0 0 0 0 ...
$ Race.Black : num 0 0 1 0 0 0 0 0 0 0 ...
$ Race.Hispanic : num 0 0 0 0 0 1 0 0 0 0 ...
$ Race.Other : num 0 0 0 0 0 0 0 1 0 0 ...
$ Race.White : num 1 1 0 1 1 0 1 0 1 1 ...
$ Blood_Type.A : num 0 0 0 1 0 0 0 0 1 0 ...
$ Blood_Type.AB : num 0 0 0 0 0 0 0 0 0 0 ...
$ Blood_Type.B : num 0 0 0 0 0 0 0 0 0 0 ...
$ Blood_Type.O : num 1 1 1 0 1 1 1 1 0 1 ...
$ Diag_Code.Congenital Heart Disease With Surgery : num 1 1 1 1 1 0 0 0 0 0 ...
$ Diag_Code.Congenital Heart Disease Without Surgery: num 0 0 0 0 0 0 0 0 0 0 ...
$ Diag_Code.Dilated Cardiomyopathy : num 0 0 0 0 0 1 1 1 0 1 ...
$ Diag_Code.Hypertrophic Cardiomyopathy : num 0 0 0 0 0 0 0 0 1 0 ...
$ Diag_Code.Myocarditis : num 0 0 0 0 0 0 0 0 0 0 ...
$ Diag_Code.Other : num 0 0 0 0 0 0 0 0 0 0 ...
$ Diag_Code.Restrictive Cardiomyopathy : num 0 0 0 0 0 0 0 0 0 0 ...
$ Diag_Code.Valvular Heart Disease : num 0 0 0 0 0 0 0 0 0 0 ...
'data.frame': 4523 obs. of 44 variables:
$ outcome : int 0 0 0 1 0 1 0 1 1 1 ...
$ Age : num 1 0 0 0 2 0 1 11 0 0 ...
..- attr(*, "label")= chr "WL AGE AT LISTING IN YEARS"
$ Gender : Factor w/ 2 levels "F","M": 1 2 2 2 1 2 1 1 1 2 ...
$ Weight : num 6.71 3.02 6.6 6 19.4 ...
$ Height : num 72 49.5 66 57 97.5 ...
$ BMI : num 12.9 12.3 15.2 18.5 20.4 ...
$ BSA : num 0.366 0.204 0.348 0.308 0.725 ...
$ PGE_TCR : num 0 0 0 0 0 0 0 0 0 0 ...
..- attr(*, "label")= chr "TCR PGE AT LISTING"
$ ECMO_Reg : num 0 0 0 1 0 0 0 0 0 0 ...
..- attr(*, "label")= chr "TCR ECMO AT LISTING"
$ VAD : num -1 -1 -1 -1 -1 -1 -1 1 -1 1 ...
$ VAD_TCR : Factor w/ 5 levels "LVAD","LVAD+RVAD",..: 3 3 3 3 3 3 3 1 3 2 ...
$ Ventilator : num 0 0 0 0 0 0 0 0 0 1 ...
..- attr(*, "label")= chr "TCR LIFE SUPPORT VENTILATOR"
$ WL_Oth_Org : Factor w/ 2 levels "No","Yes": 2 1 1 2 1 1 1 1 1 1 ...
$ Cereb_Vasc : Factor w/ 3 levels "No","Unknown",..: 1 1 1 1 1 1 1 1 1 1 ...
$ Diabetes : Factor w/ 5 levels "None","Type I",..: 1 1 1 1 1 1 1 1 1 1 ...
$ Dialysis : num -1 -1 -1 1 -1 -1 -1 -1 -1 -1 ...
$ Inotrop : num 0 0 0 0 1 0 1 0 1 0 ...
..- attr(*, "label")= chr "TCR IV INOTROPES AT LISTING"
$ Creatinine : num 0.13 0.46 0.28 0.67 0.42 0.34 0.34 0.5 0.22 0.21 ...
..- attr(*, "label")= chr "TCR MOST RECENT CREAT."
$ eGFR : num 228.2 44.3 97.1 35.1 95.6 ...
..- attr(*, "label")= chr "TCR MOST RECENT CREAT."
$ Albumin : num 4.4 3.7 3 3.8 4.6 3.9 4.6 3.6 3.2 3.2 ...
..- attr(*, "label")= chr "TCR TOTAL SERUM ALBUMIN AT LISTING (Pre 1/1/2007 for adult)"
$ Txp_Volume : num 15 2 13 18 13 15 30 11 27 16 ...
$ Med_Refusals : num 3 8 4 15 5 3 5 5 4 4 ...
$ Prop_Refusals: num 0.864 0.951 0.852 0.972 0.882 ...
$ XMatch_Req : Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 1 1 1 1 ...
$ List_Yr : num 2020 2020 2020 2020 2020 2020 2020 2020 2020 2020 ...
$ Policy_Chg : num 1 1 1 1 1 1 1 1 1 1 ...
$ Listing_Ctr : Factor w/ 89 levels "ahbent","alfoba",..: 61 34 37 69 16 61 21 71 66 36 ...
$ Race_Asian : num 0 0 0 0 0 0 0 0 0 0 ...
$ Race_Black : num 0 0 1 0 0 0 0 0 0 0 ...
$ Race_Hispanic: num 0 0 0 0 0 1 0 0 0 0 ...
$ Race_Other : num 0 0 0 0 0 0 0 1 0 0 ...
$ Race_White : num 1 1 0 1 1 0 1 0 1 1 ...
$ Blood_Type_A : num 0 0 0 1 0 0 0 0 1 0 ...
$ Blood_Type_AB: num 0 0 0 0 0 0 0 0 0 0 ...
$ Blood_Type_B : num 0 0 0 0 0 0 0 0 0 0 ...
$ Blood_Type_O : num 1 1 1 0 1 1 1 1 0 1 ...
$ CHD_Surgery : num 1 1 1 1 1 0 0 0 0 0 ...
$ CHD_NoSurgery: num 0 0 0 0 0 0 0 0 0 0 ...
$ DCM : num 0 0 0 0 0 1 1 1 0 1 ...
$ HCM : num 0 0 0 0 0 0 0 0 1 0 ...
$ Myocard : num 0 0 0 0 0 0 0 0 0 0 ...
$ Other_Diag : num 0 0 0 0 0 0 0 0 0 0 ...
$ RCM : num 0 0 0 0 0 0 0 0 0 0 ...
$ VHD : num 0 0 0 0 0 0 0 0 0 0 ...
- Test
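In the structure printed above, `Race`, `Blood_Type`, and `Diag_Code` have been expanded into 0/1 indicator columns such as `Race.Asian`, while the other factors (for example `List_Ctr` with 89 levels) remain categorical. A minimal sketch of that expansion in plain Python follows. The helper `one_hot` and the toy rows are hypothetical and only illustrate the transformation, not the code that produced the data frame.

```python
def one_hot(rows, column, sep="."):
    """Replace `column` in each row (a dict) with 0/1 indicator columns named
    '<column><sep><level>', one per level observed in the data."""
    levels = sorted({row[column] for row in rows})
    encoded = []
    for row in rows:
        new_row = {k: v for k, v in row.items() if k != column}
        for level in levels:
            new_row[f"{column}{sep}{level}"] = 1 if row[column] == level else 0
        encoded.append(new_row)
    return encoded


# Toy rows (invented); only Blood_Type is expanded.
toy_rows = [
    {"Age": 1, "Blood_Type": "O"},
    {"Age": 0, "Blood_Type": "A"},
    {"Age": 11, "Blood_Type": "O"},
]
encoded = one_hot(toy_rows, "Blood_Type")
print(encoded[0])
```

Levels are taken from the data passed in, so the training and test sets should be encoded against the same level list to keep their columns aligned.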
Hybrid CatBoost Model
#| echo: false
#| warning: false
#| message: false
import numpy as np
# initialize Train and Test datasets
hybrid_X_train = r.hybrid_train_data
hybrid_y_train = r.hybrid_train_Y
hybrid_Y_train = np.array(hybrid_y_train)
hybrid_X_test = r.hybrid_test_data
hybrid_y_test = r.hybrid_test_Y
hybrid_Y_test = np.array(hybrid_y_test)
hybrid_cat_index = get_categorical_indexes(hybrid_X_train)
Feature names are consistent between training and test datasets.
Optuna Hyperparameter Optimization
Model
- From Optuna Trial #36/50 with value: 0.74
model_params_hybrid = {
'learning_rate': 0.03,
'depth': 4,
'colsample_bylevel': 0.12,
'min_data_in_leaf': 44,
'l2_leaf_reg': 5.66
}
Warning: Overfitting detector is active, thus evaluation metric is calculated on every iteration. 'metric_period' is ignored for evaluation metric.
0: test: 0.5891032 best: 0.5891032 (0) total: 26.7ms remaining: 26.7s
Stopped by overfitting detector (50 iterations wait)
bestTest = 0.7386785449
bestIteration = 255
<catboost.core.CatBoostClassifier object at 0x000002DFA58E5AF0>
Calibration Plot
import pandas as pd
# Predicted class labels on the test set
Y_Pred_hybrid = hybrid_model.predict(hybrid_X_test)
# Class probabilities, computed once: column 0 is the negative class, column 1 the positive class
hybrid_proba = hybrid_model.predict_proba(hybrid_X_test)
Y_Pred_Proba_hybrid = hybrid_proba[:, 1]
Y_Pred_Proba_Positive_hybrid = hybrid_proba[:, 1]
Y_Pred_Proba_Negative_hybrid = hybrid_proba[:, 0]
# Collect predictions and actuals in a DataFrame for readability, including negative class probabilities
hybrid_predictions = pd.DataFrame({
    'Prob_Negative_Class': Y_Pred_Proba_Negative_hybrid,
    'Prob_Positive_Class': Y_Pred_Proba_Positive_hybrid,
    'Predicted': Y_Pred_hybrid,
    'Actual': hybrid_y_test
})

hybrid_predictions <- py$hybrid_predictions %>%
  mutate(Class = ifelse(Actual == 0, "survive", "not_survive"),
         .pred_not_survive = Prob_Positive_Class
  )
# Define the class levels
factor_levels <- c("survive", "not_survive")
# Set the levels of the 'Class' column so that "not_survive" is the first level
hybrid_predictions$Class <- factor(hybrid_predictions$Class, levels = rev(factor_levels))
hybrid_predictions %>%
  cal_plot_logistic(Class, .pred_not_survive)
Calibrated Model Metrics
Model AUC ... False Negative (FN) True Positive (TP)
0 Hybrid_Catboost 0.735938 ... 27 75
[1 rows x 13 columns]
Final Feature Importances
# Retrieve the final SHAP importance dataframe from Python
final_shap_df <- py$final_shap_df
# Convert the pandas dataframe to an R tibble, sorted by descending importance
final_shap_tbl <- as_tibble(final_shap_df) %>%
  arrange(desc(Importance))
# Keep the feature name and importance columns for display
final_shap_table <- final_shap_tbl %>%
  rowid_to_column(var = "Overall Rank") %>%
  select('Feature Id', 'Importance')
# Render an interactive DT table with centered columns
final_shap_table %>%
  DT::datatable(
    rownames = FALSE,
    options = list(
      columnDefs = list(
        list(className = 'dt-center', targets = "_all")
      )
    )
  )
Cluster Function for Top Features
# Function to cluster data based on optimal clusters
final_clustering <- function(data, optimal_k) {
  # Perform k-means clustering with the optimal number of clusters
  kmeans_res <- kmeans(as.matrix(data), centers = optimal_k, nstart = 25)
  return(kmeans_res$cluster)
}
optimal_k <- 3
# Perform clustering using the optimal number of clusters
final_shap_df <- final_shap_df %>%
  mutate(Cluster = final_clustering(select(., Importance), optimal_k))
# View the clustered data
final_shap_df %>%
  DT::datatable(
    rownames = FALSE,
    options = list(
      columnDefs = list(
        list(className = 'dt-center', targets = "_all")
      )
    )
  )
Based on WCSS (within-cluster sum of squares), ‘ECMO_Reg’ is the cutoff feature. However, several variables are correlated, so we can include a few additional features that may help explain ‘Med_Refusals’. For this reason, we set the final cutoff at ‘Txp_Volume’, which corresponds to a Feature Importance value of 0.02.
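To make the WCSS criterion concrete, the sketch below clusters a one-dimensional list of importances and reports the within-cluster sum of squares. It uses plain Lloyd iterations with a deterministic start, which is a simplification of R's `kmeans` (default Hartigan-Wong algorithm with random restarts). The function name `kmeans_1d` and the importance values are invented for illustration.

```python
def kmeans_1d(values, k, n_iter=100):
    """Lloyd's algorithm on scalar values. Initial centers are taken at evenly
    spaced positions of the sorted data so the result is deterministic."""
    ordered = sorted(values)
    centers = [ordered[(2 * j + 1) * len(ordered) // (2 * k)] for j in range(k)]
    labels = [0] * len(values)
    for _ in range(n_iter):
        # Assignment step: each value joins its nearest center.
        labels = [min(range(k), key=lambda j: (v - centers[j]) ** 2) for v in values]
        # Update step: each center moves to the mean of its members.
        new_centers = []
        for j in range(k):
            members = [v for v, lab in zip(values, labels) if lab == j]
            new_centers.append(sum(members) / len(members) if members else centers[j])
        if new_centers == centers:
            break
        centers = new_centers
    wcss = sum((v - centers[lab]) ** 2 for v, lab in zip(values, labels))
    return labels, centers, wcss


# Toy importances (invented): a high, a middle, and a low group.
toy_importance = [0.30, 0.28, 0.12, 0.10, 0.09, 0.02, 0.01, 0.01]
labels, centers, wcss3 = kmeans_1d(toy_importance, k=3)
_, _, wcss1 = kmeans_1d(toy_importance, k=1)
print(labels, round(wcss3, 5), round(wcss1, 5))
```

WCSS always falls as `k` grows, so the number of clusters is usually chosen where the drop levels off. Note that `kmeans` in R starts from random centers, so cluster numbering can differ between runs unless a seed is set.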
CatBoost Model Accuracy Summary
model_accuracy
| Model | AUC | Brier Score | Accuracy | Log Loss | F1 Score | Precision | Recall | AUPR |
|---|---|---|---|---|---|---|---|---|
| Native_Catboost | 0.708 | 0.0874 | 0.672 | 0.308 | 0.29 | 0.186 | 0.657 | 0.222 |
| Hybrid_Catboost | 0.736 | 0.0839 | 0.618 | 0.294 | 0.282 | 0.174 | 0.735 | 0.294 |
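The Brier score and log loss in the table are both computed from predicted probabilities rather than predicted classes. The sketch below shows the standard formulas on toy data and adds a reference point: the Brier score of a model that always predicts the test-set event rate, which is 102 of 1000 according to the confusion matrix counts reported in this section. Function names and toy values are illustrative assumptions.

```python
import math


def brier_score(labels, probs):
    """Mean squared difference between predicted probability and outcome."""
    return sum((p - y) ** 2 for y, p in zip(labels, probs)) / len(labels)


def log_loss(labels, probs, eps=1e-15):
    """Mean negative log-likelihood; probabilities are clipped to avoid log(0)."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(labels)


# Toy outcomes and predicted probabilities (invented for illustration).
toy_labels = [0, 0, 1, 1]
toy_probs = [0.1, 0.2, 0.7, 0.9]
brier = brier_score(toy_labels, toy_probs)
ll = log_loss(toy_labels, toy_probs)

# Reference: always predicting the event rate p gives a Brier score of p * (1 - p).
prevalence = 102 / 1000
baseline_brier = prevalence * (1 - prevalence)
print(round(brier, 4), round(ll, 4), round(baseline_brier, 4))
```

With an event rate near 10%, the constant-prediction baseline is already about 0.092, so the reported Brier scores of 0.0874 and 0.0839 are best read relative to that figure rather than to zero.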
model_confusion_matrix
| Model | True Negative (TN) | False Positive (FP) | False Negative (FN) | True Positive (TP) |
|---|---|---|---|---|
| Native_Catboost | 605 | 293 | 35 | 67 |
| Hybrid_Catboost | 543 | 355 | 27 | 75 |
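The threshold-based metrics in the accuracy summary follow directly from these four counts. As a check, the sketch below recomputes accuracy, precision, recall, and F1 from the confusion matrices above. It also reports specificity, which is not in the summary table; the function name `classification_metrics` is illustrative.

```python
def classification_metrics(tn, fp, fn, tp):
    """Derive common threshold-based metrics from confusion matrix counts."""
    total = tn + fp + fn + tp
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return {
        "accuracy": (tp + tn) / total,
        "precision": precision,
        "recall": recall,
        "specificity": tn / (tn + fp),  # not reported in the summary table
        "f1": 2 * precision * recall / (precision + recall),
    }


# Counts copied from the confusion matrix table.
native = classification_metrics(tn=605, fp=293, fn=35, tp=67)
hybrid = classification_metrics(tn=543, fp=355, fn=27, tp=75)
for name, m in (("Native_Catboost", native), ("Hybrid_Catboost", hybrid)):
    print(name, {k: round(v, 3) for k, v in m.items()})
```

The recomputed values match the summary table: the hybrid model gains recall (0.735 vs 0.657) at the cost of more false positives, which lowers its accuracy and precision.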
SHAP Value Analysis
feature_names = shap_values.feature_names
# Replace '_' with ' ' in each feature name
updated_feature_names = [name.replace('_', ' ') for name in feature_names]
shap_values.feature_names = updated_feature_names
Mean Absolute Value Feature Importance
library(ggplot2)
library(plotly)
library(dplyr)
# ggplot bar chart object
p <- ggplot(final_shap_df, aes(x = reorder(`Feature Id`, -Importance), y = Importance, fill = `Feature Id`)) +
  geom_bar(stat = "identity") +
  geom_hline(yintercept = 0.02, linetype = "dashed", color = "red") + # Add cutoff line
  annotate("text", x = 33.5, y = 0.025, label = "Cutoff@0.02 (Txp_Volume)", color = "red", hjust = 0) + # Add annotation
  labs(title = "Sorted Mean Absolute SHAP Values", x = "Features", y = "Mean Absolute SHAP Value") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position = "none") # Adjust text angle for better readability
# Interactive plotly object
p_interactive <- ggplotly(p, tooltip = c("x", "y")) # Hover effects with tooltips for both feature name and value
# Display the interactive plot
p_interactive
# Save the interactive plot as HTML
htmlwidgets::saveWidget(p_interactive, "sorted_shap_values_interactive.html")
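The bar heights in this chart are mean absolute SHAP values: for each feature, the absolute SHAP contributions are averaged over all rows. A minimal sketch of that aggregation and of the 0.02 cutoff follows, using a toy SHAP matrix with generic feature names. The values are invented and the helper `mean_abs_shap` is hypothetical.

```python
def mean_abs_shap(shap_matrix, feature_names):
    """Average |SHAP| per feature over all rows, sorted from most to least
    important."""
    n_rows = len(shap_matrix)
    importance = {
        name: sum(abs(row[j]) for row in shap_matrix) / n_rows
        for j, name in enumerate(feature_names)
    }
    return sorted(importance.items(), key=lambda kv: kv[1], reverse=True)


# Toy SHAP values (invented): rows are patients, columns are features.
toy_shap = [
    [0.10, -0.40, 0.00],
    [-0.20, 0.10, 0.01],
    [0.30, -0.30, -0.02],
]
toy_names = ["feat_a", "feat_b", "feat_c"]
ranked = mean_abs_shap(toy_shap, toy_names)

# Apply the same style of cutoff as the dashed line in the chart.
cutoff = 0.02
kept = [name for name, value in ranked if value >= cutoff]
print(ranked, kept)
```

Taking absolute values before averaging keeps positive and negative contributions from cancelling, so a feature that pushes risk up for some patients and down for others still ranks as important.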